Mulgrad ================= 计算逐元素乘法 (Mul) 操作的梯度。该算子是 Mul 算子的反向传播(backward pass)部分。梯度的计算遵循链式法则。 .. math:: \text{dx1}_i &= \text{dy}_i \times \text{x2}_i \\ \text{dx2}_i &= \text{dy}_i \times \text{x1}_i 其中 dx1 和 dx2 分别是损失函数对前向输入 x1 和 x2 的梯度。 gradmul1l 版本专门用于 `x1` 张量维度大于或等于 `x2` 张量的广播场景;gradmul2l 版本专门用于 `x2` 张量维度大于或等于 `x1` 张量的广播场景。 输入: - **dy** - 来自后一层的上游梯度张量。 - **x1** - 前向传播时的第一个输入张量(第一个乘数)。 - **x2** - 前向传播时的第二个输入张量(第二个乘数)。 - **large_shape** - `x1` 和 `x2` 中维度较大的张量的形状。 - **small_shape** - `x1` 和 `x2` 中维度较小的张量的形状。 - **out_shape** - 输出张量 `dx1` 和 `dx2` 的形状。 - **ndims** - 张量的维度数。 - **large_strides** - 维度较大张量的步长信息。 - **small_strides** - 维度较小张量的步长信息。 - **out_strides** - 输出张量的步长信息。 - **large_multiples** - 维度较大张量的广播倍数。 - **small_multiples** - 维度较小张量的广播倍数。 - **tile_data0** - 临时工作空间地址。 - **tile_data1** - 临时工作空间地址。 - **indices** - 用于广播计算的临时索引空间地址。 - **core_mask** - 核掩码(仅共享存储版本使用)。 输出: - **dx1** - 输出地址,写入计算得到的对 `x1` 的梯度。 - **dx2** - 输出地址,写入计算得到的对 `x2` 的梯度。 支持平台: ``FT78NE`` ``MT7004`` .. note:: - FT78NE 支持 fp32 - MT7004 支持 fp16, fp32 **共享存储版本:** .. c:function:: void fp_gradmul_s(float* dy, float* x1, float* x2, int* large_shape, int* small_shape, int* out_shape, int ndims, int* large_strides, int* small_strides, int* out_strides, int* large_multiples, int* small_multiples, float* dx1, float* dx2, float* tile_data0, float* tile_data1, int* indices, int core_mask) .. c:function:: void hp_gradmul_s(half* dy, half* x1, half* x2, int* large_shape, int* small_shape, int* out_shape, int ndims, int* large_strides, int* small_strides, int* out_strides, int* large_multiples, int* small_multiples, half* dx1, half* dx2, half* tile_data0, half* tile_data1, int* indices, int core_mask) .. c:function:: void fp_gradmul1l_s(float* dy, float* x1, float* x2, int* large_shape, int* small_shape, int* out_shape, int ndims, int* large_strides, int* small_strides, int* out_strides, int* large_multiples, int* small_multiples, float* dx1, float* dx2, float* tile_data0, float* tile_data1, int* indices, int core_mask) .. 
c:function:: void hp_gradmul1l_s(half* dy, half* x1, half* x2, int* large_shape, int* small_shape, int* out_shape, int ndims, int* large_strides, int* small_strides, int* out_strides, int* large_multiples, int* small_multiples, half* dx1, half* dx2, half* tile_data0, half* tile_data1, int* indices, int core_mask) .. c:function:: void fp_gradmul2l_s(float* dy, float* x1, float* x2, int* large_shape, int* small_shape, int* out_shape, int ndims, int* large_strides, int* small_strides, int* out_strides, int* large_multiples, int* small_multiples, float* dx1, float* dx2, float* tile_data0, float* tile_data1, int* indices, int core_mask) .. c:function:: void hp_gradmul2l_s(half* dy, half* x1, half* x2, int* large_shape, int* small_shape, int* out_shape, int ndims, int* large_strides, int* small_strides, int* out_strides, int* large_multiples, int* small_multiples, half* dx1, half* dx2, half* tile_data0, half* tile_data1, int* indices, int core_mask) **C调用示例:** .. code-block:: c :linenos: :emphasize-lines: 38 //FT78NE示例 #include #include int main(int argc, char* argv[]) { float *dy = (float *)0xA1000000; float *dx1 = (float *)0xA2000000; float *dx2 = (float *)0xA3000000; float *x1_data = (float *)0xA4000000; float *x2_data = (float *)0xA5000000; float *tile_data0 = (float *)0xA6000000; float *tile_data1 = (float *)0xA7000000; long long ndims = 4; long long dy_size; long long x1_size; long long x2_size; int *large_strides = (int *)0xAB000000; int *small_strides = (int *)0xAB100000; int *out_strides = (int *)0xAB200000; int *large_multiples = (int *)0xAB300000; int *small_multiples = (int *)0xAB400000; int *indices = (int *)0xAB500000; int *large_shape = (int *)0xAB600000; int *small_shape = (int *)0xAB700000; int *out_shape = (int *)0xAB800000; large_shape[0] = 12; large_shape[1] = 14; large_shape[2] = 3; large_shape[3] = 5; small_shape[0] = 12; small_shape[1] = 14; small_shape[2] = 3; small_shape[3] = 5; out_shape[0] = 12; out_shape[1] = 14; out_shape[2] = 3; out_shape[3] 
= 5; int core_mask = 0xff; dy_size = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3]; x1_size = large_shape[0] * large_shape[1] * large_shape[2] * large_shape[3]; x2_size = small_shape[0] * small_shape[1] * small_shape[2] * small_shape[3]; fp_gradmul_s(dy, x1_data, x2_data, large_shape, small_shape, out_shape, ndims, large_strides, small_strides, out_strides, large_multiples, small_multiples, dx1, dx2, tile_data0, tile_data1, indices, core_mask); return 0; } **私有存储版本:** .. c:function:: void fp_gradmul_p(float* dy, float* x1, float* x2, int* large_shape, int* small_shape, int* out_shape, int ndims, int* large_strides, int* small_strides, int* out_strides, int* large_multiples, int* small_multiples, float* dx1, float* dx2, float* tile_data0, float* tile_data1, int* indices) .. c:function:: void hp_gradmul_p(half* dy, half* x1, half* x2, int* large_shape, int* small_shape, int* out_shape, int ndims, int* large_strides, int* small_strides, int* out_strides, int* large_multiples, int* small_multiples, half* dx1, half* dx2, half* tile_data0, half* tile_data1, int* indices) .. c:function:: void fp_gradmul1l_p(float* dy, float* x1, float* x2, int* large_shape, int* small_shape, int* out_shape, int ndims, int* large_strides, int* small_strides, int* out_strides, int* large_multiples, int* small_multiples, float* dx1, float* dx2, float* tile_data0, float* tile_data1, int* indices) .. c:function:: void hp_gradmul1l_p(half* dy, half* x1, half* x2, int* large_shape, int* small_shape, int* out_shape, int ndims, int* large_strides, int* small_strides, int* out_strides, int* large_multiples, int* small_multiples, half* dx1, half* dx2, half* tile_data0, half* tile_data1, int* indices) .. c:function:: void fp_gradmul2l_p(float* dy, float* x1, float* x2, int* large_shape, int* small_shape, int* out_shape, int ndims, int* large_strides, int* small_strides, int* out_strides, int* large_multiples, int* small_multiples, float* dx1, float* dx2, float* tile_data0, float* tile_data1, int* indices) .. 
c:function:: void hp_gradmul2l_p(half* dy, half* x1, half* x2, int* large_shape, int* small_shape, int* out_shape, int ndims, int* large_strides, int* small_strides, int* out_strides, int* large_multiples, int* small_multiples, half* dx1, half* dx2, half* tile_data0, half* tile_data1, int* indices) **C调用示例:** .. code-block:: c :linenos: :emphasize-lines: 36 //FT78NE示例 #include #include int main(int argc, char* argv[]) { float *dy = (float *)0x10000000; float *dx1 = (float *)0x12000000; float *dx2 = (float *)0x13000000; float *x1_data = (float *)0x14000000; float *x2_data = (float *)0x15000000; float *tile_data0 = (float *)0x16000000; float *tile_data1 = (float *)0x17000000; long long ndims = 4; long long dy_size; long long x1_size; long long x2_size; int *large_strides = (int *)0x1B000000; int *small_strides = (int *)0x1B100000; int *out_strides = (int *)0x1B200000; int *large_multiples = (int *)0x1B300000; int *small_multiples = (int *)0x1B400000; int *indices = (int *)0x1B500000; int *large_shape = (int *)0x1B600000; int *small_shape = (int *)0x1B700000; int *out_shape = (int *)0x1B800000; large_shape[0] = 12; large_shape[1] = 14; large_shape[2] = 3; large_shape[3] = 5; small_shape[0] = 12; small_shape[1] = 14; small_shape[2] = 3; small_shape[3] = 5; out_shape[0] = 12; out_shape[1] = 14; out_shape[2] = 3; out_shape[3] = 5; dy_size = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3]; x1_size = large_shape[0] * large_shape[1] * large_shape[2] * large_shape[3]; x2_size = small_shape[0] * small_shape[1] * small_shape[2] * small_shape[3]; fp_gradmul_p(dy, x1_data, x2_data, large_shape, small_shape, out_shape, ndims, large_strides, small_strides, out_strides, large_multiples, small_multiples, dx1, dx2, tile_data0, tile_data1, indices); return 0; }